from pecos.xmc.xtransformer.matcher import TransformerMatcher
import torch
import numpy as np
import torch.nn as nn
from pecos.utils import torch_util
import os
import argparse

def _parse_args(argv=None):
    """Parse command-line options for the vanilla-BERT feature script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace whose ``save_data_dir`` has already been joined
        with ``dataset`` (i.e. ``<save_data_dir>/<dataset>``).
    """
    parser = argparse.ArgumentParser(description='Generate node features from vanilla Bert')
    parser.add_argument('--dataset', type=str, default="ogbn-arxiv")
    # NOTE(review): not read anywhere in this script; kept for CLI compatibility.
    parser.add_argument('--data_root_dir', type=str, default="./dataset")
    parser.add_argument('--save_data_dir', type=str, default="./data_for_XRTransformer")
    parser.add_argument('--truncate_length', type=int, default=128)
    parser.add_argument('--batch_size_per_gpu', type=int, default=64)
    parser.add_argument('--batch_gen_workers', type=int, default=64)
    args = parser.parse_args(argv)
    print(args)

    # Outputs and inputs live in a per-dataset sub-directory.
    args.save_data_dir = os.path.join(args.save_data_dir, args.dataset)
    return args


def _prediction_batch_size(n_gpu, per_gpu=64):
    """Return the total prediction batch size for ``n_gpu`` devices.

    ``n_gpu`` may be 0 (presumably a CPU-only run); a plain ``n_gpu * per_gpu``
    would then yield an invalid batch size of 0, so at least one device's
    worth of samples is always used.
    """
    return max(n_gpu, 1) * per_gpu


def main(argv=None):
    """Embed tokenized node text with an un-finetuned BERT and save it.

    Reads ``X.Tokenized.val.pt`` from ``<save_data_dir>/<dataset>``, runs it
    through a ``bert-base-uncased`` TransformerMatcher and writes the second
    value returned by ``predict`` to
    ``<save_data_dir>/<dataset>/Results/VanillaBert.npy``.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    # NOTE(review): the second argument is presumably the number of labels;
    # only the returned embeddings are used, so it looks like a placeholder.
    matcher = TransformerMatcher.download_model("bert-base-uncased", 2)
    device, n_gpu = torch_util.setup_device(True)
    matcher.to_device(device, n_gpu=n_gpu)
    batch_size = _prediction_batch_size(n_gpu, args.batch_size_per_gpu)
    pred_params = matcher.pred_params.from_dict({
        "truncate_length": args.truncate_length,
        "batch_size": batch_size,
        "batch_gen_workers": args.batch_gen_workers,
    })

    # np.save does not create missing parent directories; create it up front
    # so a bad output path fails before the expensive prediction runs.
    results_dir = os.path.join(args.save_data_dir, "Results")
    os.makedirs(results_dir, exist_ok=True)

    print("Loading input tensor.")
    X_tensor = torch.load(os.path.join(args.save_data_dir, 'X.Tokenized.val.pt'))
    print("Input tensor loaded.")
    print("Start generating node features.")
    # predict returns a pair; only the second element is kept as node features.
    _, embedding = matcher.predict(
        X_tensor,
        pred_params=pred_params,
        batch_gen_workers=args.batch_gen_workers,
        batch_size=batch_size,
    )
    print("Got node features. Start saving node features")
    np.save(os.path.join(results_dir, "VanillaBert.npy"), embedding)
    print("Node features saved.")


if __name__ == "__main__":
    main()
